WMMA grouped conv fwd large tensor bias bnorm clamp #3595
Conversation
Force-pushed from 7b0341d to b5c541f.
Review comment on ...eration/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp (outdated, resolved).
Looks good overall! I have some comments:
Review comment on ...tance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp (outdated, resolved).
Force-pushed from 23c6688 to 7306d8b.
bartekxk left a comment:
LGTM, please rebase.
Copilot left a comment:
Pull request overview
This pull request adds WMMA (Wave Matrix Multiply Accumulate) kernel support for grouped convolution forward operations with bias, batch normalization, and clamp activation for large tensors. The implementation targets FP16 and BF16 data types with NHWGC/GKYXC/NHWGK tensor layouts.
Changes:
- Added WMMA kernel instances for 2D and 3D grouped convolutions with bias+bnorm+clamp operations
- Updated test infrastructure to enable bias+bnorm+clamp tests on gfx9, gfx11, and gfx12 GPU targets
- Modified initialization ranges in the profiler to ensure monotone operations, improving numerical stability on RDNA3 architectures
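As a rough illustration of what the fused epilogue computes per output element, here is a minimal host-side sketch, assuming inference-style batch normalization with precomputed statistics; the struct and field names are hypothetical and do not reflect CK's actual elementwise-operation API.

```cpp
#include <algorithm>

// Hypothetical per-element epilogue fused after the convolution:
//   y = clamp(gamma * ((conv + bias) - mean) * inv_std + beta, lo, hi)
// assuming inference-style batch norm with precomputed statistics.
struct BiasBnormClamp
{
    float mean;    // precomputed batch-norm mean
    float inv_std; // precomputed 1 / sqrt(var + eps)
    float gamma;   // batch-norm scale
    float beta;    // batch-norm shift
    float lo;      // clamp lower bound
    float hi;      // clamp upper bound

    float operator()(float conv_out, float bias) const
    {
        const float x  = conv_out + bias;                     // add bias
        const float bn = gamma * (x - mean) * inv_std + beta; // normalize
        return std::clamp(bn, lo, hi);                        // clamp
    }
};
```

Fusing this into the WMMA kernel's output stage avoids materializing the intermediate bias and batch-norm results in global memory, which is the usual motivation for these combined instances.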
Reviewed changes
Copilot reviewed 23 out of 23 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp | Updated function call from bias_clamp to bias_bnorm_clamp |
| test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp | Updated function call from bias_clamp to bias_bnorm_clamp |
| test/grouped_convnd_fwd_activation/CMakeLists.txt | Moved bias_bnorm_clamp tests to gfx9/11/12 target section |
| test/CMakeLists.txt | Added test to regression test list |
| profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp | New profiler entry point for bias_bnorm_clamp operation |
| profiler/src/CMakeLists.txt | Added profiler source and device instances to build |
| profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp | Changed initialization ranges to monotone values for numerical stability |
| library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/*.cpp | WMMA large tensor instances for 3D conv (f16/bf16) |
| library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/*.cpp | WMMA instances for 3D conv (f16/bf16) |
| library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt | Updated build configuration to include WMMA instances |
| library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/*.cpp | WMMA large tensor instances for 2D conv (f16/bf16) |
| library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/*.cpp | WMMA instances for 2D conv (f16/bf16) |
| library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt | Updated build configuration to include WMMA instances |
| library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc | Forward declarations for WMMA instance functions |
| library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp | Factory integration for WMMA instances |
| library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp | Generic instance templates for WMMA large tensor |
| library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp | Generic instance templates and type aliases |
| include/ck/utility/array.hpp | Added Emplace method for in-place construction (see the sketch after this table) |
| include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp | Refactored to use Emplace and added NumDTensor guards |
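On the include/ck/utility/array.hpp row above: a minimal sketch of what an Emplace method for in-place construction can look like, assuming a placement-new based design. StaticArray and its storage layout are hypothetical, not the actual ck::Array implementation.

```cpp
#include <cstddef>
#include <new>
#include <utility>

// Illustrative fixed-size array with an Emplace method that constructs the
// element in place via placement new instead of copy-assigning a temporary.
// The real ck::Array signature may differ; destructor handling is omitted
// for brevity (fine for trivially-destructible kernel argument types).
template <typename T, std::size_t N>
struct StaticArray
{
    alignas(T) unsigned char storage_[N * sizeof(T)];

    template <typename... Args>
    T& Emplace(std::size_t i, Args&&... args)
    {
        // Construct T directly in slot i's storage.
        return *::new(storage_ + i * sizeof(T)) T(std::forward<Args>(args)...);
    }

    T& operator[](std::size_t i)
    {
        return *std::launder(reinterpret_cast<T*>(storage_ + i * sizeof(T)));
    }
};
```

Constructing in place avoids requiring the element type to be default-constructible or copy-assignable, which is presumably why the large-tensor device implementation was refactored to use it.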
Review comment on ...a/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp (outdated, resolved).
Review comment on .../device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp (outdated, resolved).
Force-pushed from 35f342c to 427b259.
The base branch was changed.
The following operations are added for the FP16/BF16 data types and NHWGCxGKYXC layout:
- grouped_conv2d_fwd_bias_bnorm_clamp
- grouped_conv3d_fwd_bias_bnorm_clamp
Force-pushed from 427b259 to 41ca771.
Proposed changes
Added the bias + bnorm + clamp operation for the WMMA grouped conv fwd large-tensor path (FP16/BF16 data types, NHWGCxGKYXC layout).
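On the profiler change noted in the review (initialization ranges switched to monotone values): a minimal sketch of the idea, assuming the goal is non-negative inputs that avoid cancellation during FP16/BF16 accumulation. make_monotone_init and the [0, 1] range are illustrative, not the profiler's actual values.

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Illustrative sketch only: drawing inputs from a small non-negative range so
// that partial sums grow monotonically (no cancellation), which keeps
// FP16/BF16 accumulation error small when comparing against a host reference.
// The actual ranges used by the CK profiler may differ.
std::vector<float> make_monotone_init(std::size_t n, unsigned seed = 0)
{
    std::mt19937 gen{seed};
    std::uniform_real_distribution<float> dist{0.0f, 1.0f}; // non-negative
    std::vector<float> v(n);
    for(float& x : v)
        x = dist(gen);
    return v;
}
```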
Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.